library(conflicted)
conflicts_prefer(dplyr::filter)
library(BMTCensusManuscript)
library(tidyverse)
library(timetk)
library(rsample)
library(parsnip)
library(workflows)
library(workflowsets)
library(modeltime)
library(modeltime.resample)
library(poissonreg)
library(parallel)
library(future)
library(dials)
library(recipes)
library(tune)
library(yardstick)
library(kableExtra)
theme_set(theme_minimal())Reproducibility Supplement
Introduction
This document provides computer code that was used to create the tables, figures, and results which we report in our manuscript Multi-week Forecasting of Pediatric Blood and Marrow Transplant Hospital Census: Development and Real-World Validation. We provide this document to allow other interested parties to inspect, verify, and reproduce our analytic work.
The first section lists concise descriptions of the steps performed in our Analyses. To view the code (written in the R programming language) that was used to perform a specific step, as well as the results, click the black triangle to the left of the description to expand each section. The code relies heavily on the BMTCensusManuscript package developed to house code to reproduce the manuscript results. While the order of this section mirrors the order of results reported in the main manuscript, the section headings follow an analytical order laid out in the Materials and Methods section of the manuscript to allow greater depth in this document. Tables and figures from the main manuscript and supplementary materials are cross-referenced.
The second section describes the precise Computational Environment in which the computational analyses were performed. Among other parameters, this part lists the names, versions, and sources of all software packages that were used to generate this document.
Analyses
Setup
Load required libraries
Data
Load census, admissions, and scheduled admissions data
# Date separating the training period from the held-out test period
train_test_split_date <- as_date("2022-09-30")
# Core inputs: census (patient-day level), admissions, and scheduled admissions
census_data <- readRDS("manuscript/data/census.rds")
admission_data <- readRDS("manuscript/data/admissions.rds")
# Collapse scheduled admissions to counts per transplant type / date pair
scheduled_admissions <- readRDS("manuscript/data/scheduled_admissions.rds") |>
  count(transplant_type, date, admission_date, name = "scheduled")Plot length of stay by admission type (Figure 1a)
# Date when the data was refreshed; encounters still open on this date are
# censored here when computing length of stay (see `los` below)
censoring_date <- as_date("2024-2-29")
# Recode transplant type into the manuscript's admission-type labels, with a
# malignant/non-malignant split for allogeneic HSCT
manuscript_census_data <- census_data |>
  mutate(
    admission_type = case_match(
      transplant_type,
      "Allogeneic" ~ "Allo HSCT",
      "Autologous" ~ "Auto HSCT",
      "No Transplant" ~ "Other"
    ),
    admission_type = factor(admission_type, levels = c("Auto HSCT", "Allo HSCT", "Other"), ordered = TRUE),
    admission_type_malig = case_when(
      admission_type == "Allo HSCT" & malignant == "Yes" ~ "Allo HSCT (Malignant)",
      admission_type == "Allo HSCT" & malignant == "No" ~ "Allo HSCT (Non-Malignant)",
      .default = admission_type
    ),
    admission_type_malig = factor(admission_type_malig, levels = c("Auto HSCT", "Allo HSCT (Malignant)", "Allo HSCT (Non-Malignant)", "Other"), ordered = TRUE)
  )
# One row per encounter (csn); LOS in days, right-censored at censoring_date
# for encounters with no recorded end date
manuscript_encounter_data <- manuscript_census_data |>
  filter(bmt_team_start_date >= "2016-01-01") |>
  distinct(csn, bmt_team_start_date, bmt_team_end_date, admission_type, malignant, malignancy_history, admission_type_malig) |>
  mutate(
    censored = is.na(bmt_team_end_date),
    los = as.numeric(coalesce(bmt_team_end_date, censoring_date) - bmt_team_start_date, unit = "days")
  )
# Figure 1a: length-of-stay density by admission type (x-axis truncated at 80)
manuscript_encounter_data |>
  ggplot(aes(x = los, color = admission_type)) +
  geom_density() +
  coord_cartesian(xlim = c(0, 80)) +
  labs(
    x = "Length of Stay (Days)",
    y = NULL,
    color = "Admission Type"
  ) +
  theme(axis.text.y = element_blank()) +
  scale_color_chop()Plot daily census by admission type (Figure 1b)
# Figure 1b: stacked daily midnight census; complete() makes days with no
# patients of a given type explicit zeros so the area plot has no gaps
fig_1b_data <- manuscript_census_data |>
  filter(time > 0, date >= "2021-01-01") |>
  count(date, admission_type, name = "value") |>
  complete(
    date = seq(as_date("2021-01-01"), as_date("2023-06-30"), by = "day"),
    admission_type,
    fill = list(value = 0)
  )
fig_1b_data |>
  ggplot(aes(x = date, y = value, fill = admission_type)) +
  geom_area(aes(group = fct_relevel(admission_type, "Allo HSCT", after = Inf))) +
  labs(
    fill = "Admission Type",
    y = "Midnight Census",
    x = NULL
  ) +
  scale_fill_chop() +
  scale_x_date(date_labels = "%b '%y")Schematic overview of census prediction model (Figure 2a)
Plot representative example of census components (Figure 2b)
# Figure 2b: decompose one example day's forecast into the census contributed
# by patients already admitted ("current") and by future admissions ("new")
fig_2b <- list()
fig_2b$date <- as_date("2023-04-01")
fig_2b$h <- 65
fig_2b$base_data <- manuscript_census_data |>
  filter(time > 0, date == fig_2b$date)
# Current-admissions component: patients present on fig_2b$date still not
# discharged at each horizon .h
fig_2b$current_admit_data <- fig_2b$base_data |>
  extend_discharge_data(fig_2b$h, end_var = bmt_team_end_date) |>
  mutate(date = fig_2b$date + .h) |>
  summarize(value = sum(discharge == "No"), .by = date)
# Add in time 0 row
fig_2b$current_admit_data <- bind_rows(
  count(fig_2b$base_data, date, name = "value"),
  fig_2b$current_admit_data
)
# Future-admissions component: census contribution of admissions after
# fig_2b$date, up to horizon h
fig_2b$new_admit_data <- admission_data |>
  filter(date == fig_2b$date) |>
  unnest(census_contribution) |>
  filter(h <= fig_2b$h) |>
  mutate(date = date + h) |>
  summarize(value = sum(actual), .by = date)
# Add in time 0 row
fig_2b$new_admit_data <- bind_rows(
  tibble(date = fig_2b$date, value = 0),
  fig_2b$new_admit_data
)
# Strict one-to-one join: every horizon must appear in both components
fig_2b$data <- inner_join(
  rename(fig_2b$current_admit_data, current = value),
  rename(fig_2b$new_admit_data, new = value),
  by = "date",
  relationship = "one-to-one",
  unmatched = "error"
) |>
  mutate(total = current + new, x = row_number() - 1) |>
  select(-date)
fig_2b$text_y <- 6
# Display names and colors for the three plotted series
components <- list()
components$names <- c("current", "new", "total")
components$labs <- c("Current Admissions Component", "Future Admissions Component", "Total")
components$colors <- chop_colors(c("pink", "blue", "green"))
levels <- components$names |>
  set_names(components$labs)
colors <- components$colors |>
  set_names(components$labs)
fig_2b$data |>
  pivot_longer(-x) |>
  mutate(name = fct_recode(
    factor(name, components$names),
    !!!levels
  )) |>
  ggplot(aes(x = x, y = value, color = name)) +
  geom_line(aes(linewidth = name)) +
  geom_vline(xintercept = c(0, 30)) +
  annotate("text",
    y = fig_2b$text_y,
    x = 30,
    label = "Prediction Horizon",
    angle = 90,
    vjust = -.25
  ) +
  annotate("text",
    y = fig_2b$text_y,
    x = 0,
    label = "Current Time",
    angle = 90,
    vjust = -.25
  ) +
  labs(
    color = NULL,
    y = "Midnight Census",
    x = NULL
  ) +
  coord_cartesian(xlim = c(as_date(NA), fig_2b$h - 5)) +
  # Only "Total" gets the heavier line; the other series fall back to na.value
  scale_linewidth_manual(values = c(Total = 1), na.value = .5, guide = "none") +
  scale_color_manual(values = colors) +
  theme_minimal(base_size = 14)Generate descriptive statistics of LOS and admission rates (Table 1)
# Table 1: median LOS with IQR, and average share of the daily census, per
# admission type
table_1 <- list()
table_1$los_data <- manuscript_encounter_data |>
  summarize(
    median = median(los),
    q1 = quantile(los, probs = .25),
    q3 = quantile(los, probs = .75),
    .by = admission_type_malig
  )
# Daily census share per type, averaged over the study window; complete()
# inserts explicit zero-count days so the shares are computed correctly
table_1$census_share_data <- manuscript_census_data |>
  filter(time > 0, date >= "2016-01-01") |>
  count(date, admission_type_malig, name = "value") |>
  complete(
    date = seq(as_date("2016-01-01"), as_date("2023-06-30"), by = "day"),
    admission_type_malig,
    fill = list(value = 0)
  ) |>
  mutate(pct = value / sum(value), .by = date) |>
  summarize(pct = mean(pct), .by = admission_type_malig)
table_1$census_share_data |>
  arrange(admission_type_malig) |>
  left_join(table_1$los_data, by = "admission_type_malig") |>
  mutate(
    across(c(median, q1, q3), \(x) round(x, 1)),
    q1_q3 = paste(q1, q3, sep = ", "),
    pct = paste0(round(pct*100, 1), "%")
  ) |>
  select(
    `Admission Type` = admission_type_malig,
    `Average % of Census` = pct,
    `Median LOS (Days)` = median,
    `LOS Q1, Q3 (Days)` = q1_q3
  )| Admission Type | Average % of Census | Median LOS (Days) | LOS Q1, Q3 (Days) |
|---|---|---|---|
| Auto HSCT | 21.4% | 26.7 | 22.7, 34 |
| Allo HSCT (Malignant) | 25.3% | 40.9 | 35.1, 54 |
| Allo HSCT (Non-Malignant) | 29% | 37.1 | 31.3, 52.1 |
| Other | 24.3% | 3.0 | 1.5, 7.6 |
Model Selection
Split census data into train and test set
Patients may have multiple encounters. In order to prevent leakage between training and testing we remove patients from the training data that appear in the test set.
# Split the census data at the fixed date; train_test_split() also removes
# test-set patients from training to prevent leakage (see prose above)
census_data_split <- census_data |>
  # Don't need these extra columns
  select(!any_of(c("admission_type", "admit_sched_date"))) |>
  train_test_split(split_date = train_test_split_date)Split admission data into train and test set
# Time-based split of the admission series: training is everything up to the
# split date (cumulative), assessment covers the remainder
test_length <- as.numeric(max(admission_data$date) - train_test_split_date, unit = "days")
admissions_ts_split <- admission_data |>
  time_series_split(
    date_var = date,
    assess = test_length,
    cumulative = TRUE
  )discharge_train <- training(census_data_split)
admission_train <- training(admissions_ts_split)
discharge_test <- testing(census_data_split)
admission_test <- testing(admissions_ts_split)
# Persist the splits so downstream documents reuse identical partitions
save_data(census_data_split, "manuscript/data-split/census.rds")
save_data(admissions_ts_split, "manuscript/data-split/admissions.rds")Define candidates for discharge and admission models
The discharge model is a binary classification model predicting discharge on a given day.
The admission model is a count model predicting the number of new admissions per day by transplant_type.
Discharge: - Logistic Regression - XGBoost
Admission: - Poisson Regression - XGBoost
# Discharge candidates: binary classifiers predicting discharge on a given day
discharge_models <- list(
  logistic = logistic_reg(),
  xgboost = boost_tree(mode = "classification")
)
discharge_wflowset <- workflow_set(
  preproc = list(baseline = discharge_rec(discharge_train)),
  models = discharge_models
)
# Admission candidates: count models for daily admissions by transplant type;
# the XGBoost variant uses a Poisson objective to match the count target
admission_wflow_poisson <- workflow(
  preprocessor = admission_rec(admission_train, type = "poisson"),
  spec = poisson_reg("regression")
)
admission_wflow_xgb <- workflow(
  preprocessor = admission_rec(admission_train, type = "xgb"),
  spec = boost_tree("regression") |>
    set_engine("xgboost", objective = "count:poisson")
)
admission_model_tbl <- modeltime_table(
  fit(admission_wflow_poisson, admission_train),
  fit(admission_wflow_xgb, admission_train)
)
# NOTE(review): .model_id is overwritten with character labels here;
# downstream summaries appear to key on these labels rather than the default
# integer ids -- confirm fit_admission_resamples() tolerates this
admission_model_tbl$.model_id <- c("poisson", "xgboost")Create folds for cross-validation
A single fold consists of a discharge fold and an admission fold aligned by prediction date, the first date in the test split of the fold.
Each discharge fold contains:
- A test split with the census on the prediction date
- A train split with the census for all days before the prediction date
Each admission fold contains:
- A test split with admissions data for the prediction date and the 90 days after
- A train split with admission data for all days before the prediction date
# Build aligned discharge/admission CV folds (see prose above); each fold
# assesses a 90-day horizon from its prediction date
splits <- make_cv_splits(
  discharge_train,
  admission_train,
  h = 90,
  scheduled_admissions = scheduled_admissions
)Run cross-validation to compare candidate models
discharge_results_cv <- fit_discharge_resamples(discharge_wflowset, splits)
admission_results_cv <- fit_admission_resamples(admission_model_tbl, splits$admission)
# Combine every discharge/admission model pair into joint census predictions
cv_results <- collect_cv_results(admission_results_cv, discharge_results_cv, splits)
save_data(cv_results, "manuscript/results/cv_results.rds")Compute RMSE for model selection results
# Reshape paired .actual_*/.pred_* columns to long form, then compute RMSE per
# model pair and horizon; keep only the total-census component
cv_results_summary <- cv_results |>
  select(admission_model, discharge_model, id, results) |>
  unnest(results) |>
  filter(.h %in% c(15, 30, 60, 90)) |>
  pivot_longer(
    matches("actual|pred"),
    names_to = c(".value", "stat"),
    names_pattern = "\\.(actual|pred)_(.+)"
  ) |>
  group_by(admission_model, discharge_model, stat, .h) |>
  rmse(actual, pred) |>
  filter(stat == "census")Generate line plot of model RMSE over prediction horizon
cv_results_summary |>
  ggplot(aes(color = paste(admission_model, discharge_model, sep = "; "), y = .estimate, x = factor(.h))) +
  geom_point() +
  geom_line(aes(group = paste(admission_model, discharge_model))) +
  labs(y = "RMSE", x = "h (days ahead)", color = NULL)Model Tuning
Having selected the XGBoost model for both the admission and discharge models we now jointly tune all hyperparameters using grid search.
Create tuning grid for discharge model
# Discharge XGBoost with four tunable hyperparameters; trees fixed at 1000
discharge_model_tune <- boost_tree(
  mode = "classification",
  sample_size = tune(),
  learn_rate = tune(),
  trees = 1000,
  min_n = tune(),
  mtry = tune()
)
# Finalize data-dependent parameter ranges (e.g. mtry) against the baked
# training predictors
params <- discharge_model_tune |>
  extract_parameter_set_dials() |>
  finalize(bake(prep(discharge_rec(discharge_train)), new_data = NULL))
discharge_wflowset_tune <- workflow_set(
  preproc = list(baseline = discharge_rec(discharge_train)),
  models = list(xgboost = discharge_model_tune)
) |>
  option_add(grid = 20, param_info = params)Fit models on discharge tuning grid
# Fit the discharge grid across CV folds in parallel
plan("multisession")
discharge_results_tune <- fit_discharge_resamples(discharge_wflowset_tune, splits)
plan("sequential")Create tuning grid for admission model
# Admission XGBoost: same four tunables, 20 space-filling configurations
admission_model_tune <- boost_tree(
  mode = "regression",
  trees = 1000,
  sample_size = tune(),
  learn_rate = tune(),
  min_n = tune(),
  mtry = tune()
) |>
  set_engine("xgboost", objective = "count:poisson")
admission_grid <- admission_model_tune |>
  extract_parameter_set_dials() |>
  finalize(juice(prep(admission_rec(admission_train, type = "xgb")))) |>
  grid_space_filling(size = 20)
# Expand the grid into concrete model specs for modeltime
admission_model_grid <- create_model_grid(
  grid = admission_grid,
  f_model_spec = boost_tree,
  engine_name = "xgboost",
  mode = "regression",
  trees = 1000,
  engine_params = list(objective = "count:poisson")
)Fit models on admission tuning grid
plan("multisession")
# Fit every grid configuration on the full admission training data in parallel
admission_model_tbl_tune <- workflow_set(
  preproc = list(admission_rec(admission_train, type = "xgb")),
  models = admission_model_grid$.models
) |>
  modeltime_fit_workflowset(data = admission_train, control = control_fit_workflowset(allow_par = TRUE))
# Label the tuned grid entries and attach their hyperparameters. All entries
# share one model family, so a single label is recycled across rows.
admission_model_tbl_tune$.model_id <- "xgboost"
# seq_len() instead of 1:nrow(): 1:nrow(x) yields c(1, 0) when x has no rows
admission_model_tbl_tune$.config <- seq_len(nrow(admission_model_tbl_tune))
# One single-row tibble of hyperparameters per grid configuration
admission_model_tbl_tune$hyperparams <- admission_model_grid |>
  select(-.models) |>
  pmap(tibble)
admission_results_tune <- fit_admission_resamples(admission_model_tbl_tune, splits$admission)
plan("sequential")Generate results for joint discharge-admission tuning grid
# Pair every admission configuration with every discharge configuration
plan("multisession")
tune_results <- collect_cv_results(admission_results_tune, discharge_results_tune, splits)
plan("sequential")
save_data(tune_results, "manuscript/results/tune_results.rds")Plot ranked MAPE for top 30 model configuration
# RMSE and MAPE per configuration/horizon for the total census
mset <- metric_set(rmse, mape)
tune_results_summary <- tune_results |>
  rename(discharge = discharge_hyperparams, admission = admission_hyperparams) |>
  unnest(c(discharge, admission), names_sep = "_") |>
  unnest(results) |>
  filter(.h %in% c(15, 30, 60, 90)) |>
  pivot_longer(
    matches("actual|pred"),
    names_to = c(".value", "stat"),
    names_pattern = "\\.(actual|pred)_(.+)"
  ) |>
  # pick() carries the hyperparameter columns through to the metric output
  group_by(admission_config, discharge_config, stat, .h, pick(matches("mtry|min_n|learn_rate|sample_size$"))) |>
  mset(actual, pred) |>
  filter(stat == "census")tune_results_summary |>
  filter(.metric == "mape") |>
  mutate(avg_estimate = mean(.estimate), .by = c(admission_config, discharge_config)) |>
  mutate(rank = rank(.estimate), .by = .h) |>
  mutate(avg_rank = dense_rank(avg_estimate)) |>
  filter(avg_rank <= 30) |>
  ggplot(aes(y = fct_rev(factor(avg_rank)), x = rank)) +
  geom_jitter(aes(color = factor(.h)), height = 0)Plot ranked RMSE for top 30 model configuration
# Same ranking plot as the MAPE figure above, but for RMSE
tune_results_summary |>
  filter(.metric == "rmse") |>
  mutate(avg_estimate = mean(.estimate), .by = c(admission_config, discharge_config)) |>
  mutate(rank = rank(.estimate), .by = .h) |>
  mutate(avg_rank = dense_rank(avg_estimate)) |>
  filter(avg_rank <= 30) |>
  ggplot(aes(y = fct_rev(factor(avg_rank)), x = rank)) +
  geom_jitter(aes(color = factor(.h)), height = 0)Select model with minimal RMSE
# Identify the configuration pair with the lowest RMSE averaged across horizons
best_config <- tune_results_summary |>
  filter(.metric == "rmse") |>
  mutate(avg_estimate = mean(.estimate), .by = c(admission_config, discharge_config)) |>
  filter(dense_rank(avg_estimate) == 1) |>
  distinct(admission_config, discharge_config)
# Pull the hyperparameter tibbles for the winning configuration
final_params <- tune_results |>
  filter(
    discharge_config == best_config$discharge_config,
    admission_config == best_config$admission_config
  ) |>
  distinct(admission_hyperparams, discharge_hyperparams) |>
  as.list() |>
  # Extract the single tibble from each list-column. The original
  # `map(pluck(1))` only worked by accident: `pluck(1)` with no indices
  # evaluates to `1`, which purrr then treats as an extractor. Spell the
  # lambda out explicitly instead.
  map(\(x) pluck(x, 1))
final_params$admission_hyperparams| mtry | min_n | learn_rate | sample_size |
|---|---|---|---|
| 3 | 9 | 0.0036468 | 0.9712415 |
final_params$discharge_hyperparams| mtry | min_n | learn_rate | sample_size |
|---|---|---|---|
| 6 | 28 | 0.005141 | 0.1383986 |
Model Evaluation
Fit selected model on full training data
# Refit the selected discharge and admission models on the full training data
# with the tuned hyperparameters
discharge_model_train <- workflow(
  discharge_rec(discharge_train),
  discharge_model_tune
) |>
  finalize_workflow(final_params$discharge_hyperparams) |>
  fit(discharge_train)
admission_model_train <- workflow(
  admission_rec(admission_train, type = "xgb"),
  admission_model_tune
) |>
  finalize_workflow(final_params$admission_hyperparams) |>
  fit(admission_train) |>
  modeltime_table()
# Bundle both components into the combined census model
models_train <- bmt_model(discharge_model = discharge_model_train, admission_model = admission_model_train)
save_data(models_train, "manuscript/models/models_train.rds")Predict on held-out test data
# Predict 1-90 days ahead from every test date
test_dates <- unique(admission_test$date)
test_results <- make_test_data(test_dates, discharge_test, admission_test, scheduled_admissions)
test_results$results <- pmap(test_results, \(current_admissions, scheduled_admissions, ...) {
  predict_census(
    models = models_train,
    h = 1:90,
    current_census = current_admissions,
    scheduled_admissions = scheduled_admissions
  )
})
save_data(test_results, "manuscript/results/test_results.rds")Prospective validation data are the results of an automated process that made daily, prospective predictions of the census.
Load prospective validation data
validation_results <- readRDS("manuscript/data/validation_results.rds")
# Block bootstrap (7-day blocks) of the prospective validation results,
# seeded for reproducibility
withr::with_seed(71919, {
  validation_results_bootstrap <- validation_results |>
    filter(.h %in% c(15, 30, 60), !is.na(.actual_census)) |>
    generate_bootstraps(bootstrap_block_size = 7)
})Calculate evaluation metrics for test set and prospective validation (Table 2)
# Join per-date predictions to actuals by test date and horizon, keeping
# horizons up to 60 days to match the validation window
manuscript_test_results <- left_join(
  test_results |>
    select(test_date, results) |>
    unnest(results),
  test_results |>
    select(test_date, actual) |>
    unnest(actual),
  by = c("test_date", ".h")
) |>
  filter(.h <= 60)
# Table 2: total-census RMSE/MAPE for the test set and prospective validation,
# plus bootstrap MAPE upper bounds and the pre-specified pass thresholds
table_2 <- list()
table_2$test <- make_metrics_table(manuscript_test_results) |>
  filter(Component == "Total Census") |>
  select(-Component)
table_2$validation <- make_metrics_table(validation_results) |>
  filter(Component == "Total Census") |>
  select(-Component)
table_2$thresholds <- tibble(
  Horizon = c(15, 30, 60),
  `Prospective Validation MAPE Threshold` = c(15, 20, 25),
  `Prospective Validation Result` = "Pass"
)
# 95th-percentile MAPE from the block bootstrap
table_2$boot <- summarize_bootstrap(validation_results_bootstrap)$table |>
  filter(.metric == "mape") |>
  select(Horizon = .h, `Prospective Validation MAPE 95% Bound` = `0.95`)
left_join(
  rename_with(table_2$test, \(x) paste("Test Set", x), c(RMSE, MAPE)),
  rename_with(table_2$validation, \(x) paste("Prospective Validation", x), c(RMSE, MAPE)),
  by = "Horizon"
) |>
  left_join(table_2$boot, by = "Horizon") |>
  left_join(table_2$thresholds, by = "Horizon") |>
  mutate(
    across(ends_with("RMSE"), \(x) round(x, 1)),
    across(ends_with("MAPE"), \(x) round(x, 1)),
    across(ends_with("MAPE 95% Bound"), \(x) round(x, 1))
  ) |>
  kbl(
    col.names = c(
      "Prediction Horizon (Days)",
      "RMSE",
      "MAPE",
      "RMSE",
      "MAPE",
      "MAPE 95% Upper Bound",
      "MAPE 95% Upper Bound Threshold",
      "Result"
    ),
    align = "c"
  ) |>
  add_header_above(c(" " = 1, "Held-out Test Set" = 2, "Prospective Validation" = 5), bold = TRUE)| Prediction Horizon (Days) | RMSE | MAPE | RMSE | MAPE | MAPE 95% Upper Bound | MAPE 95% Upper Bound Threshold | Result |
|---|---|---|---|---|---|---|---|
| 15 | 2.5 | 13.2 | 2.5 | 12.4 | 12.3 | 15 | Pass |
| 30 | 3.1 | 16.4 | 1.7 | 8.1 | 7.8 | 20 | Pass |
| 60 | 3.6 | 18.2 | 2.5 | 11.2 | 12.2 | 25 | Pass |
Plot total census predictions for the test set (Figure 3a)
# Figure 3a: predicted vs actual total census on the held-out test set
manuscript_test_results |>
  plot_results(.pred_census, .actual_census)Plot total census predictions for the prospective validation (Figure 3b)
# Figure 3b: same plot for the prospective validation period
validation_results |>
  plot_results(.pred_census, .actual_census) +
  scale_y_continuous(breaks = c(12, 15, 18, 21))Plot sample prediction curves for current admissions component on the test set (Figure S1a)
# Figures S1/S2: sample prediction curves per census component; validation
# curves are restricted to prediction dates with a full 60-day horizon
manuscript_test_results |>
  plot_prediction_curves(
    pred = .pred_current_admission,
    actual = .actual_current_admission,
    group = test_date,
    n_show = 10
  )Plot sample prediction curves for current admissions component on the prospective validation (Figure S1b)
validation_results |>
  filter(max(.h) == 60, .by = prediction_date) |>
  plot_prediction_curves(
    pred = .pred_current_admission,
    actual = .actual_current_admission,
    group = prediction_date,
    n_show = 6
  )Plot sample prediction curves for future admissions component on the test set (Figure S2a)
manuscript_test_results |>
  plot_prediction_curves(
    pred = .pred_new_admission,
    actual = .actual_new_admission,
    group = test_date,
    n_show = 10
  )Plot sample prediction curves for future admissions component on the prospective validation (Figure S2b)
validation_results |>
  filter(max(.h) == 60, .by = prediction_date) |>
  plot_prediction_curves(
    pred = .pred_new_admission,
    actual = .actual_new_admission,
    group = prediction_date,
    n_show = 6
  )Fit model on full dataset
# Refit both components on all available data for deployment
discharge_model_final <- fit(models_train$discharge, census_data)
admission_model_final <- modeltime_refit(models_train$admission, admission_data)
models_final <- bmt_model(discharge_model = discharge_model_final, admission_model = admission_model_final)
save_data(models_final, "manuscript/models/models_final.rds")Supplementary Analyses
Create total census time series
# Total daily census as a single time series, split at the same date as the
# main analysis (test_length was computed for the admission split above)
census_ts_data <- census_data |>
  filter(time > 0, date >= "2016-01-01") |>
  count(date, name = "value")
census_ts_split <- census_ts_data |>
  time_series_split(
    date_var = date,
    assess = test_length,
    cumulative = TRUE
  )
census_ts_train <- training(census_ts_split)
reframe(census_ts_train, range(date))| range(date) |
|---|
| 2016-01-01 |
| 2022-09-30 |
census_ts_test <- testing(census_ts_split)
reframe(census_ts_test, range(date))| range(date) |
|---|
| 2022-10-01 |
| 2023-06-30 |
Predict test set using heuristic models
post_process_naive_results <- function(results, h) {
  # Flatten a modeltime resample table into one row per requested horizon:
  # within each resample id, .h is the day-ahead position of each prediction
  # and only the positions listed in `h` are retained.
  extract_horizons <- function(resample_result) {
    resample_result |>
      select(id, .predictions) |>
      unnest(.predictions) |>
      mutate(.h = row_number(), .by = id) |>
      slice(h, .by = id)
  }
  results |>
    select(.model_desc, .resample_results) |>
    mutate(preds = map(.resample_results, extract_horizons)) |>
    select(-.resample_results) |>
    unnest(preds)
}
# Include prior year of data for seasonal naive model
census_ts_test_leadin <- bind_rows(
  census_ts_train |>
    filter(date > train_test_split_date - 365),
  census_ts_test
)
# Naive (carry-forward) and seasonal naive baselines
naive_models <- modeltime_table(
  naive_reg() |>
    set_engine("naive") |>
    fit(value ~ date, census_ts_train),
  naive_reg(seasonal_period = 12) |>
    set_engine("snaive") |>
    fit(value ~ date, census_ts_train)
)
# Rolling-origin resamples over the test period, 60-day assessment windows
census_ts_test_leadin_split <- time_series_cv(
  census_ts_test_leadin,
  date_var = date,
  assess = "60 days",
  initial = 365,
  cumulative = TRUE
)
naive_results <- naive_models |>
  modeltime_fit_resamples(census_ts_test_leadin_split) |>
  post_process_naive_results(h = c(15, 30, 60))
# Mean baseline: predict the training-window average for every horizon
mean_results <- tibble(
  id = census_ts_test_leadin_split$id,
  .pred = map_dbl(census_ts_test_leadin_split$splits, \(x) mean(training(x)$value))
)
# Reuse the naive result frame for the mean model, swapping in its predictions
heuristic_results <- bind_rows(
  naive_results,
  naive_results |>
    select(-.pred) |>
    mutate(.model_desc = "MEAN") |>
    left_join(mean_results, by = "id", relationship = "many-to-one", unmatched = "error")
)
save_data(heuristic_results, "manuscript/results/heuristic_results.rds")Fit simple time series models
# Poisson and XGBoost models fit directly to the total census series
ts_wflow_poisson <- workflow(
  preprocessor = admission_rec(census_ts_train, type = "poisson"),
  spec = poisson_reg("regression")
)
ts_wflow_xgb <- workflow(
  preprocessor = admission_rec(census_ts_train, type = "xgb"),
  spec = boost_tree(
    "regression",
    trees = 1000
  ) |>
    set_engine("xgboost", objective = "count:poisson")
)
ts_models <- modeltime_table(
  fit(ts_wflow_poisson, census_ts_train),
  fit(ts_wflow_xgb, census_ts_train)
) |>
  update_modeltime_description(1, "poisson") |>
  update_modeltime_description(2, "xgboost")Create tuning grid for simple time series models
# Rolling-origin tuning splits within the training series
ts_xgb_tune_split <- make_census_ts_splits(census_ts_train, 60, slice_limit = 15, skip = "3 months")
ts_xgb_tune <- boost_tree(
  mode = "regression",
  trees = 1000,
  sample_size = tune(),
  learn_rate = tune(),
  min_n = tune(),
  mtry = tune()
) |>
  set_engine("xgboost", objective = "count:poisson")
# Narrow the learn_rate range (log10 scale) before building the 20-point grid
ts_xgb_grid <- ts_xgb_tune |>
  extract_parameter_set_dials() |>
  update(learn_rate = learn_rate(range = c(-2.3, -1.6))) |>
  finalize(juice(prep(admission_rec(census_ts_train, type = "xgb")))) |>
  grid_space_filling(size = 20)
ts_xgb_model_grid <- create_model_grid(
  grid = ts_xgb_grid,
  f_model_spec = boost_tree,
  engine_name = "xgboost",
  mode = "regression",
  trees = 1000,
  engine_params = list(objective = "count:poisson")
)Fit models on time series tuning grid
plan("multisession")
# Fit every grid configuration on the training series in parallel
ts_xgb_model_tbl_tune <- workflow_set(
  preproc = list(admission_rec(census_ts_train, type = "xgb")),
  models = ts_xgb_model_grid$.models
) |>
  modeltime_fit_workflowset(data = census_ts_train, control = control_fit_workflowset(allow_par = TRUE))
# Label the tuned grid entries and attach their hyperparameters.
# seq_len() instead of 1:nrow(): 1:nrow(x) yields c(1, 0) when x has no rows
ts_xgb_model_tbl_tune$.config <- seq_len(nrow(ts_xgb_model_tbl_tune))
# One single-row tibble of hyperparameters per grid configuration
ts_xgb_model_tbl_tune$hyperparams <- ts_xgb_model_grid |>
  select(-.models) |>
  pmap(tibble)
# Evaluate the grid on the rolling-origin splits
ts_tune_results <- modeltime_fit_resamples(
  ts_xgb_model_tbl_tune,
  resamples = ts_xgb_tune_split
) |>
  postprocess_ts_resamples(pred_name = ".pred")
plan("sequential")
save_data(ts_tune_results, "manuscript/results/ts_tune_results.rds")Plot RMSE and MAPE for all configurations in tuning grid
# RMSE/MAPE per configuration and horizon (mset defined in the tuning section)
ts_tune_results_summary <- ts_tune_results |>
  unnest(hyperparams) |>
  unnest(results) |>
  filter(.h %in% c(15, 30, 60)) |>
  group_by(.model_id, .h, .config, pick(matches("mtry|min_n|learn_rate|sample_size$"))) |>
  mset(value, .pred)
ts_tune_results_summary |>
  ggplot(aes(x = .estimate, y = factor(.config), color = factor(.h))) +
  geom_point() +
  facet_wrap(~.metric, scales = "free_x")Select time series model with minimal RMSE
# Pick the configuration with the lowest RMSE averaged across horizons
ts_best_config <- ts_tune_results_summary |>
  filter(.metric == "rmse") |>
  mutate(avg_estimate = mean(.estimate), .by = .config) |>
  filter(dense_rank(avg_estimate) == 1) |>
  distinct(.config)
ts_final_params <- ts_tune_results |>
  filter(
    .config == ts_best_config$.config
  ) |>
  distinct(hyperparams) |>
  pull(hyperparams) |>
  pluck(1)
ts_final_params| mtry | min_n | learn_rate | sample_size |
|---|---|---|---|
| 1 | 14 | 0.0164357 | 0.4315789 |
# Refit the tuned XGBoost on the full training series
ts_xgb_tuned <- workflow(
  admission_rec(census_ts_train, type = "xgb"),
  ts_xgb_tune
) |>
  finalize_workflow(ts_final_params) |>
  fit(census_ts_train) |>
  modeltime_table()Predict on test set with all time series models
# Add tuned xgb to models
ts_models <- combine_modeltime_tables(
  ts_models,
  ts_xgb_tuned
) |>
  # Entry 3 is the newly appended tuned model
  update_modeltime_description(3, "xgboost tuned")
forecast_test_date <- function(test_date, test_data, models, h = 1:60) {
  # Forecast the census from `test_date` with every model in `models`,
  # returning predictions only for the day-ahead positions listed in `h`.
  horizon <- max(h)
  # One row per day from test_date through the final horizon, joined to the
  # observed test values (missing days yield NA actuals)
  prediction_frame <- tibble(
    date = test_date + (seq_len(horizon) - 1),
    .h = seq_len(horizon)
  ) |>
    left_join(test_data, by = "date")
  calibrated <- modeltime_calibrate(models, new_data = prediction_frame)
  calibrated |>
    modeltime_forecast(
      prediction_frame,
      keep_data = TRUE
    ) |>
    filter(.h %in% h) |>
    select(all_of(colnames(prediction_frame)), .value, .model_id, .model_desc)
}
# Sample 40 test dates for prediction
# NOTE(review): this sample() call is not wrapped in a seed (unlike the
# bootstrap, which uses withr::with_seed), so the sampled dates differ
# between runs -- consider seeding for full reproducibility
ts_test_dates <- census_ts_test |>
  filter(!is.na(lead(value, 60))) |>  # keep only dates with 60 days of actuals ahead
  pull(date) |>
  unique() |>
  sample(40)
ts_test_results <- tibble(test_date = ts_test_dates)
ts_test_results$results <- map(
  ts_test_results$test_date,
  forecast_test_date,
  test_data = census_ts_test,
  models = ts_models
)
saveRDS(ts_test_results, "manuscript/results/ts_test_results.rds")Create summary metrics table for heuristic and time series models (Table S1)
# Table S1: compare heuristic baselines, simple time-series models, and the
# preferred component-based model on the test set
heuristic_results_summary <- heuristic_results |>
  group_by(.model_desc, .h) |>
  mset(value, .pred)
ts_results_summary <- ts_test_results |>
  unnest(results) |>
  filter(.h %in% c(15, 30, 60), !is.na(value)) |>
  group_by(.model_desc, .h) |>
  mset(value, .value)
bind_rows(
  heuristic_results_summary,
  ts_results_summary |>
    filter(.model_desc %in% c("poisson", "xgboost tuned")),
  # Reuse the test-set metrics already computed for Table 2
  table_2$test |>
    select(.h = Horizon, mape = MAPE, rmse = RMSE) |>
    pivot_longer(c(mape, rmse), names_to = ".metric", values_to = ".estimate") |>
    mutate(.model_desc = "Preferred")
) |>
  mutate(
    # Map internal model descriptions to publication labels
    .model_desc = case_match(
      .model_desc,
      "MEAN" ~ "Mean",
      "SNAIVE [12]" ~ "Seasonal Naive",
      "NAIVE" ~ "Naive",
      "poisson" ~ "Poisson",
      "xgboost tuned" ~ "GBDT",
      .default = .model_desc
    ),
    model_type = case_match(
      .model_desc,
      c("Mean", "Seasonal Naive", "Naive") ~ "Heuristic Model",
      c("Poisson", "GBDT") ~ "Time Series Model",
      "Preferred" ~ "Component-based Model"
    ),
    .estimate = round(.estimate, 1)
  ) |>
  pivot_wider(id_cols = c(.h, model_type, .model_desc), names_from = .metric, values_from = .estimate) |>
  kbl(
    col.names = c(
      "Prediction Horizon (Days)",
      "Model Type",
      "Model",
      "RMSE",
      "MAPE"
    ),
    align = "c"
  ) |>
  # Bold the preferred model's rows
  row_spec(16:18, bold = TRUE)| Prediction Horizon (Days) | Model Type | Model | RMSE | MAPE |
|---|---|---|---|---|
| 15 | Heuristic Model | Mean | 4.4 | 30.9 |
| 30 | Heuristic Model | Mean | 4.5 | 31.6 |
| 60 | Heuristic Model | Mean | 4.6 | 31.9 |
| 15 | Heuristic Model | Naive | 3.6 | 19.7 |
| 30 | Heuristic Model | Naive | 4.3 | 24.6 |
| 60 | Heuristic Model | Naive | 5.6 | 30.6 |
| 15 | Heuristic Model | Seasonal Naive | 4.0 | 21.9 |
| 30 | Heuristic Model | Seasonal Naive | 4.6 | 27.1 |
| 60 | Heuristic Model | Seasonal Naive | 5.6 | 30.6 |
| 15 | Time Series Model | Poisson | 3.1 | 20.9 |
| 30 | Time Series Model | Poisson | 3.1 | 20.4 |
| 60 | Time Series Model | Poisson | 3.7 | 23.1 |
| 15 | Time Series Model | GBDT | 3.2 | 21.6 |
| 30 | Time Series Model | GBDT | 3.1 | 20.6 |
| 60 | Time Series Model | GBDT | 3.9 | 23.2 |
| 15 | Component-based Model | Preferred | 2.5 | 13.2 |
| 30 | Component-based Model | Preferred | 3.1 | 16.4 |
| 60 | Component-based Model | Preferred | 3.6 | 18.2 |
Computational Environment
Computational Environment
sessioninfo::session_info()─ Session info ───────────────────────────────────────────────────────────────
setting value
version R version 4.4.3 (2025-02-28)
os macOS Sequoia 15.7.3
system aarch64, darwin20
ui X11
language (EN)
collate en_US.UTF-8
ctype en_US.UTF-8
tz America/New_York
date 2025-12-29
pandoc 3.6.3 @ /Applications/Positron.app/Contents/Resources/app/quarto/bin/tools/aarch64/ (via rmarkdown)
quarto 1.7.32 @ /Applications/quarto/bin/quarto
─ Packages ───────────────────────────────────────────────────────────────────
! package * version date (UTC) lib source
P anytime 0.3.12 2025-07-14 [?] CRAN (R 4.4.1)
P backports 1.5.0 2024-05-23 [?] CRAN (R 4.4.0)
BMTCensusManuscript * 0.0.1.0 2025-12-29 [1] local
P broom 1.0.10 2025-09-13 [?] CRAN (R 4.4.1)
P cachem 1.1.0 2024-05-16 [?] CRAN (R 4.4.0)
P class 7.3-23 2025-01-01 [?] CRAN (R 4.4.3)
P cli 3.6.5 2025-04-23 [?] CRAN (R 4.4.1)
P codetools 0.2-20 2024-03-31 [?] CRAN (R 4.4.3)
P conflicted * 1.2.0 2023-02-01 [?] CRAN (R 4.4.0)
P data.table 1.17.8 2025-07-10 [?] CRAN (R 4.4.1)
P dials * 1.4.2 2025-09-04 [?] CRAN (R 4.4.1)
P DiceDesign 1.10 2023-12-07 [?] CRAN (R 4.4.0)
P digest 0.6.38 2025-11-12 [?] CRAN (R 4.4.1)
P distributional 0.5.0 2024-09-17 [?] CRAN (R 4.4.1)
P dplyr * 1.1.4 2023-11-17 [?] CRAN (R 4.4.0)
P ellipsis 0.3.2 2021-04-29 [?] CRAN (R 4.4.0)
P evaluate 1.0.5 2025-08-27 [?] CRAN (R 4.4.1)
P fabletools 0.5.1 2025-09-01 [?] CRAN (R 4.4.1)
P farver 2.1.2 2024-05-13 [?] CRAN (R 4.4.0)
P fastmap 1.2.0 2024-05-15 [?] CRAN (R 4.4.0)
P feasts 0.4.2 2025-08-27 [?] CRAN (R 4.4.1)
P forcats * 1.0.1 2025-09-25 [?] CRAN (R 4.4.1)
P furrr 0.3.1 2022-08-15 [?] CRAN (R 4.4.0)
P future * 1.67.0 2025-07-29 [?] CRAN (R 4.4.1)
P future.apply 1.20.0 2025-06-06 [?] CRAN (R 4.4.1)
P generics 0.1.4 2025-05-09 [?] CRAN (R 4.4.1)
P ggplot2 * 4.0.0 2025-09-11 [?] CRAN (R 4.4.1)
P globals 0.18.0 2025-05-08 [?] CRAN (R 4.4.1)
P glue 1.8.0 2024-09-30 [?] CRAN (R 4.4.1)
P gower 1.0.2 2024-12-17 [?] CRAN (R 4.4.1)
P GPfit 1.0-9 2025-04-12 [?] CRAN (R 4.4.1)
P gtable 0.3.6 2024-10-25 [?] CRAN (R 4.4.1)
P hardhat 1.4.2 2025-08-20 [?] CRAN (R 4.4.1)
P hms 1.1.4 2025-10-17 [?] CRAN (R 4.4.1)
P htmltools 0.5.8.1 2024-04-04 [?] CRAN (R 4.4.0)
P htmlwidgets 1.6.4 2023-12-06 [?] CRAN (R 4.4.0)
P infer 1.0.9 2025-06-26 [?] CRAN (R 4.4.1)
P ipred 0.9-15 2024-07-18 [?] CRAN (R 4.4.0)
P jsonlite 2.0.0 2025-03-27 [?] CRAN (R 4.4.1)
P kableExtra * 1.4.0 2024-01-24 [?] CRAN (R 4.4.0)
P knitr 1.50 2025-03-16 [?] CRAN (R 4.4.1)
P labeling 0.4.3 2023-08-29 [?] CRAN (R 4.4.0)
P lattice 0.22-6 2024-03-20 [?] CRAN (R 4.4.3)
P lava 1.8.2 2025-10-30 [?] CRAN (R 4.4.1)
P lhs 1.2.0 2024-06-30 [?] CRAN (R 4.4.1)
P lifecycle 1.0.4 2023-11-07 [?] CRAN (R 4.4.0)
P listenv 0.10.0 2025-11-02 [?] CRAN (R 4.4.1)
P lubridate * 1.9.4 2024-12-08 [?] CRAN (R 4.4.1)
P magrittr 2.0.4 2025-09-12 [?] CRAN (R 4.4.1)
P MASS 7.3-64 2025-01-04 [?] CRAN (R 4.4.3)
P Matrix 1.7-2 2025-01-23 [?] CRAN (R 4.4.3)
P memoise 2.0.1 2021-11-26 [?] CRAN (R 4.4.0)
P modeldata 1.5.1 2025-08-22 [?] CRAN (R 4.4.1)
P modeltime * 1.3.2 2025-08-28 [?] CRAN (R 4.4.1)
P modeltime.resample * 0.2.3 2023-04-12 [?] CRAN (R 4.4.0)
P nnet 7.3-20 2025-01-01 [?] CRAN (R 4.4.3)
P parallelly 1.45.1 2025-07-24 [?] CRAN (R 4.4.1)
P parsnip * 1.3.3 2025-08-31 [?] CRAN (R 4.4.1)
P pillar 1.11.1 2025-09-17 [?] CRAN (R 4.4.1)
P pkgconfig 2.0.3 2019-09-22 [?] CRAN (R 4.4.0)
P poissonreg * 1.0.1 2022-08-22 [?] CRAN (R 4.4.0)
P prodlim 2025.04.28 2025-04-28 [?] CRAN (R 4.4.1)
P progressr 0.18.0 2025-11-06 [?] CRAN (R 4.4.1)
P purrr * 1.2.0 2025-11-04 [?] CRAN (R 4.4.3)
P R6 2.6.1 2025-02-15 [?] CRAN (R 4.4.1)
P RColorBrewer 1.1-3 2022-04-03 [?] CRAN (R 4.4.0)
P Rcpp 1.1.0 2025-07-02 [?] CRAN (R 4.4.1)
P RcppParallel 5.1.11-1 2025-08-27 [?] CRAN (R 4.4.1)
P readr * 2.1.5 2024-01-10 [?] CRAN (R 4.4.0)
P recipes * 1.3.1 2025-05-21 [?] CRAN (R 4.4.1)
renv 1.1.5 2025-07-24 [1] CRAN (R 4.4.1)
P rlang 1.1.6 2025-04-11 [?] CRAN (R 4.4.1)
P rmarkdown 2.30 2025-09-28 [?] CRAN (R 4.4.1)
P rpart 4.1.24 2025-01-07 [?] CRAN (R 4.4.3)
P rsample * 1.3.1 2025-07-29 [?] CRAN (R 4.4.1)
P rstudioapi 0.17.1 2024-10-22 [?] CRAN (R 4.4.1)
P S7 0.2.0 2024-11-07 [?] CRAN (R 4.4.1)
P scales * 1.4.0 2025-04-24 [?] CRAN (R 4.4.0)
P sessioninfo 1.2.3 2025-02-05 [?] CRAN (R 4.4.1)
P sfd 0.1.0 2024-01-08 [?] CRAN (R 4.4.0)
P sparsevctrs 0.3.4 2025-05-25 [?] CRAN (R 4.4.1)
P StanHeaders 2.32.10 2024-07-15 [?] CRAN (R 4.4.0)
P stringi 1.8.7 2025-03-27 [?] CRAN (R 4.4.1)
P stringr * 1.6.0 2025-11-04 [?] CRAN (R 4.4.3)
P survival 3.8-3 2024-12-17 [?] CRAN (R 4.4.3)
P svglite 2.2.2 2025-10-21 [?] CRAN (R 4.4.1)
P systemfonts 1.3.1 2025-10-01 [?] CRAN (R 4.4.1)
P tailor 0.1.0 2025-08-25 [?] CRAN (R 4.4.1)
P textshaping 1.0.4 2025-10-10 [?] CRAN (R 4.4.1)
P tibble * 3.3.0 2025-06-08 [?] CRAN (R 4.4.1)
P tidymodels 1.4.1 2025-09-08 [?] CRAN (R 4.4.1)
P tidyr * 1.3.1 2024-01-24 [?] CRAN (R 4.4.0)
P tidyselect 1.2.1 2024-03-11 [?] CRAN (R 4.4.0)
P tidyverse * 2.0.0 2023-02-22 [?] CRAN (R 4.4.0)
P timechange 0.3.0 2024-01-18 [?] CRAN (R 4.4.0)
P timeDate 4051.111 2025-10-17 [?] CRAN (R 4.4.1)
P timetk * 2.9.1 2025-08-29 [?] CRAN (R 4.4.1)
P tsibble 1.1.6 2025-01-30 [?] CRAN (R 4.4.1)
P tune * 2.0.1 2025-10-17 [?] CRAN (R 4.4.1)
P tzdb 0.5.0 2025-03-15 [?] CRAN (R 4.4.1)
P vctrs 0.6.5 2023-12-01 [?] CRAN (R 4.4.0)
P viridisLite 0.4.2 2023-05-02 [?] CRAN (R 4.4.0)
P withr 3.0.2 2024-10-28 [?] CRAN (R 4.4.1)
P workflows * 1.3.0 2025-08-27 [?] CRAN (R 4.4.1)
P workflowsets * 1.1.1 2025-05-27 [?] CRAN (R 4.4.1)
P xfun 0.54 2025-10-30 [?] CRAN (R 4.4.1)
P xgboost 1.7.11.1 2025-05-15 [?] CRAN (R 4.4.1)
P xml2 1.4.1 2025-10-27 [?] CRAN (R 4.4.1)
P xts 0.14.1 2024-10-15 [?] CRAN (R 4.4.1)
P yaml 2.3.10 2024-07-26 [?] CRAN (R 4.4.0)
P yardstick * 1.3.2 2025-01-22 [?] CRAN (R 4.4.1)
P zoo 1.8-14 2025-04-10 [?] CRAN (R 4.4.1)
[1] /Users/porterej/Library/Caches/org.R-project.R/R/renv/library/BMTCensusManuscript-80fd98b5/macos/R-4.4/aarch64-apple-darwin20
[2] /Users/porterej/Library/Caches/org.R-project.R/R/renv/sandbox/macos/R-4.4/aarch64-apple-darwin20/f7156815
* ── Packages attached to the search path.
P ── Loaded and on-disk path mismatch.
──────────────────────────────────────────────────────────────────────────────